// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// A binary wrapper for QuicClient.
// Connects to a host using QUIC, sends a request to the provided URL, and
// displays the response.
//
// Some usage examples:
//
//   TODO(rtenneti): make --host optional by getting IP Address of URL's host.
//
//   Get the IP address of www.google.com:
//   IP=`dig www.google.com +short | head -1`
//
// Standard request/response:
//   quic_client http://www.google.com  --host=${IP}
//   quic_client http://www.google.com --quiet  --host=${IP}
//   quic_client https://www.google.com --port=443  --host=${IP}
//
// Use a specific version:
//   quic_client http://www.google.com --quic-version=23  --host=${IP}
//
// Send a POST instead of a GET:
//   quic_client http://www.google.com --body="this is a POST body" --host=${IP}
//
// Append additional headers to the request:
//   quic_client http://www.google.com  --host=${IP} \
//               --headers="Header-A: 1234; Header-B: 5678"
//
// Connect to a host different from the one in the URL being requested:
//   Get the IP address of www.google.com:
//   IP=`dig www.google.com +short | head -1`
//   quic_client http://mail.google.com --host=${IP}
//
// Try to connect to a host which does not speak QUIC:
//   Get the IP address of www.example.com:
//   IP=`dig www.example.com +short | head -1`
//   quic_client http://www.example.com --host=${IP}

#include <iostream>

#include "base/at_exit.h"
#include "base/command_line.h"
#include "base/logging.h"
#include "base/message_loop/message_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/privacy_mode.h"
#include "net/cert/cert_verifier.h"
#include "net/cert/multi_log_ct_verifier.h"
#include "net/http/transport_security_state.h"
#include "net/log/net_log.h"
#include "net/quic/crypto/proof_verifier_chromium.h"
#include "net/quic/quic_flags.h"
#include "net/quic/quic_protocol.h"
#include "net/quic/quic_server_id.h"
#include "net/quic/quic_utils.h"
#include "net/spdy/spdy_header_block.h"
#include "net/tools/epoll_server/epoll_server.h"
#include "net/tools/quic/quic_client.h"
#include "net/tools/quic/spdy_balsa_utils.h"
#include "net/tools/quic/synchronous_host_resolver.h"
#include "url/gurl.h"

using base::StringPiece;
using net::CertVerifier;
using net::CTPolicyEnforcer;
using net::CTVerifier;
using net::MultiLogCTVerifier;
using net::ProofVerifierChromium;
using net::TransportSecurityState;
using std::cerr;
using std::cout;
using std::endl;
using std::string;
using std::vector;

// Command-line flag storage. main() copies each matching switch into the
// corresponding variable; the initializers below are the values used when a
// switch is absent.

// IP address or hostname the client connects to. When left empty, main()
// falls back to the host of the requested URL.
string FLAGS_host;
// Port to connect to; 0 means "use the URL's effective port".
int32_t FLAGS_port{0};
// Request body. When non-empty, a POST carrying it is sent instead of a GET.
string FLAGS_body;
// Hex-encoded request body, decoded to raw bytes before being sent as the
// body of a POST, e.g. --body_hex=\"68656c6c6f\".
string FLAGS_body_hex;
// Semicolon-separated list of key:value pairs appended to the request headers.
string FLAGS_headers;
// When true, the request/response dump is suppressed.
bool FLAGS_quiet{false};
// QUIC version to speak, e.g. 21. The sentinel -1 means every supported
// version is offered in the handshake.
int32_t FLAGS_quic_version{-1};
// When true, a handshake version mismatch is not treated as a failure, which
// is handy for probing whether a server speaks any version of QUIC at all.
bool FLAGS_version_mismatch_ok{false};
// When true, an HTTP 3xx response code counts as success rather than
// failure. Defaults to true.
bool FLAGS_redirect_is_success{true};
// Initial MTU of the connection; 0 means "use the default max packet size".
int32_t FLAGS_initial_mtu{0};

// CertVerifier that accepts every certificate without looking at it.
// main() installs it only when --disable-certificate-verification is given;
// it must never be used where certificate validation actually matters.
class FakeCertVerifier : public net::CertVerifier {
public:
    // Synchronously reports success for any request. None of the arguments
    // are inspected, and |verify_result| is left untouched.
    int Verify(const RequestParams& /* params */,
        net::CRLSet* /* crl_set */,
        net::CertVerifyResult* /* verify_result */,
        const net::CompletionCallback& /* callback */,
        std::unique_ptr<Request>* /* out_req */,
        const net::BoundNetLog& /* net_log */) override
    {
        const int kAccepted = net::OK;
        return kAccepted;
    }

    // This verifier has no support for stapled OCSP responses.
    bool SupportsOCSPStapling() override
    {
        return false;
    }
};

int main(int argc, char* argv[])
{
    base::CommandLine::Init(argc, argv);
    base::CommandLine* line = base::CommandLine::ForCurrentProcess();
    const base::CommandLine::StringVector& urls = line->GetArgs();

    logging::LoggingSettings settings;
    settings.logging_dest = logging::LOG_TO_SYSTEM_DEBUG_LOG;
    CHECK(logging::InitLogging(settings));

    if (line->HasSwitch("h") || line->HasSwitch("help") || urls.empty()) {
        const char* help_str = "Usage: quic_client [options] <url>\n"
                               "\n"
                               "<url> with scheme must be provided (e.g. http://www.google.com)\n\n"
                               "Options:\n"
                               "-h, --help                  show this help message and exit\n"
                               "--host=<host>               specify the IP address of the hostname to "
                               "connect to\n"
                               "--port=<port>               specify the port to connect to\n"
                               "--body=<body>               specify the body to post\n"
                               "--body_hex=<body_hex>       specify the body_hex to be printed out\n"
                               "--headers=<headers>         specify a semicolon separated list of "
                               "key:value pairs to add to request headers\n"
                               "--quiet                     specify for a quieter output experience\n"
                               "--quic-version=<quic version> specify QUIC version to speak\n"
                               "--version_mismatch_ok       if specified a version mismatch in the "
                               "handshake is not considered a failure\n"
                               "--redirect_is_success       if specified an HTTP response code of 3xx "
                               "is considered to be a successful response, otherwise a failure\n"
                               "--initial_mtu=<initial_mtu> specify the initial MTU of the connection"
                               "\n"
                               "--disable-certificate-verification do not verify certificates\n";
        cout << help_str;
        exit(0);
    }
    if (line->HasSwitch("host")) {
        FLAGS_host = line->GetSwitchValueASCII("host");
    }
    if (line->HasSwitch("port")) {
        if (!base::StringToInt(line->GetSwitchValueASCII("port"), &FLAGS_port)) {
            std::cerr << "--port must be an integer\n";
            return 1;
        }
    }
    if (line->HasSwitch("body")) {
        FLAGS_body = line->GetSwitchValueASCII("body");
    }
    if (line->HasSwitch("body_hex")) {
        FLAGS_body_hex = line->GetSwitchValueASCII("body_hex");
    }
    if (line->HasSwitch("headers")) {
        FLAGS_headers = line->GetSwitchValueASCII("headers");
    }
    if (line->HasSwitch("quiet")) {
        FLAGS_quiet = true;
    }
    if (line->HasSwitch("quic-version")) {
        int quic_version;
        if (base::StringToInt(line->GetSwitchValueASCII("quic-version"),
                &quic_version)) {
            FLAGS_quic_version = quic_version;
        }
    }
    if (line->HasSwitch("version_mismatch_ok")) {
        FLAGS_version_mismatch_ok = true;
    }
    if (line->HasSwitch("redirect_is_success")) {
        FLAGS_redirect_is_success = true;
    }
    if (line->HasSwitch("initial_mtu")) {
        if (!base::StringToInt(line->GetSwitchValueASCII("initial_mtu"),
                &FLAGS_initial_mtu)) {
            std::cerr << "--initial_mtu must be an integer\n";
            return 1;
        }
    }

    VLOG(1) << "server host: " << FLAGS_host << " port: " << FLAGS_port
            << " body: " << FLAGS_body << " headers: " << FLAGS_headers
            << " quiet: " << FLAGS_quiet
            << " quic-version: " << FLAGS_quic_version
            << " version_mismatch_ok: " << FLAGS_version_mismatch_ok
            << " redirect_is_success: " << FLAGS_redirect_is_success
            << " initial_mtu: " << FLAGS_initial_mtu;

    base::AtExitManager exit_manager;
    base::MessageLoopForIO message_loop;

    // Determine IP address to connect to from supplied hostname.
    net::IPAddress ip_addr;

    GURL url(urls[0]);
    string host = FLAGS_host;
    if (host.empty()) {
        host = url.host();
    }
    int port = FLAGS_port;
    if (port == 0) {
        port = url.EffectiveIntPort();
    }
    if (!ip_addr.AssignFromIPLiteral(host)) {
        net::AddressList addresses;
        int rv = net::SynchronousHostResolver::Resolve(host, &addresses);
        if (rv != net::OK) {
            LOG(ERROR) << "Unable to resolve '" << host
                       << "' : " << net::ErrorToShortString(rv);
            return 1;
        }
        ip_addr = addresses[0].address();
    }

    string host_port = net::IPAddressToStringWithPort(ip_addr, port);
    VLOG(1) << "Resolved " << host << " to " << host_port << endl;

    // Build the client, and try to connect.
    net::EpollServer epoll_server;
    net::QuicServerId server_id(url.host(), url.EffectiveIntPort(),
        net::PRIVACY_MODE_DISABLED);
    net::QuicVersionVector versions = net::QuicSupportedVersions();
    if (FLAGS_quic_version != -1) {
        versions.clear();
        versions.push_back(static_cast<net::QuicVersion>(FLAGS_quic_version));
    }
    // For secure QUIC we need to verify the cert chain.
    std::unique_ptr<CertVerifier> cert_verifier(CertVerifier::CreateDefault());
    if (line->HasSwitch("disable-certificate-verification")) {
        cert_verifier.reset(new FakeCertVerifier());
    }
    std::unique_ptr<TransportSecurityState> transport_security_state(
        new TransportSecurityState);
    transport_security_state.reset(new TransportSecurityState);
    std::unique_ptr<CTVerifier> ct_verifier(new MultiLogCTVerifier());
    std::unique_ptr<CTPolicyEnforcer> ct_policy_enforcer(new CTPolicyEnforcer());
    ProofVerifierChromium* proof_verifier = new ProofVerifierChromium(
        cert_verifier.get(), ct_policy_enforcer.get(),
        transport_security_state.get(), ct_verifier.get());
    net::QuicClient client(net::IPEndPoint(ip_addr, FLAGS_port), server_id,
        versions, &epoll_server, proof_verifier);
    client.set_initial_max_packet_length(
        FLAGS_initial_mtu != 0 ? FLAGS_initial_mtu : net::kDefaultMaxPacketSize);
    if (!client.Initialize()) {
        cerr << "Failed to initialize client." << endl;
        return 1;
    }
    if (!client.Connect()) {
        net::QuicErrorCode error = client.session()->error();
        if (FLAGS_version_mismatch_ok && error == net::QUIC_INVALID_VERSION) {
            cout << "Server talks QUIC, but none of the versions supported by "
                 << "this client: " << QuicVersionVectorToString(versions) << endl;
            // Version mismatch is not deemed a failure.
            return 0;
        }
        cerr << "Failed to connect to " << host_port
             << ". Error: " << net::QuicUtils::ErrorToString(error) << endl;
        return 1;
    }
    cout << "Connected to " << host_port << endl;

    // Construct the string body from flags, if provided.
    string body = FLAGS_body;
    if (!FLAGS_body_hex.empty()) {
        DCHECK(FLAGS_body.empty()) << "Only set one of --body and --body_hex.";
        body = net::QuicUtils::HexDecode(FLAGS_body_hex);
    }

    // Construct a GET or POST request for supplied URL.
    net::BalsaHeaders headers;
    headers.SetRequestFirstlineFromStringPieces(body.empty() ? "GET" : "POST",
        url.spec(), "HTTP/1.1");

    // Append any additional headers supplied on the command line.
    for (const std::string& header :
        base::SplitString(FLAGS_headers, ";", base::KEEP_WHITESPACE,
            base::SPLIT_WANT_NONEMPTY)) {
        string sp;
        base::TrimWhitespaceASCII(header, base::TRIM_ALL, &sp);
        if (sp.empty()) {
            continue;
        }
        vector<string> kv = base::SplitString(sp, ":", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL);
        CHECK_EQ(2u, kv.size());
        string key;
        base::TrimWhitespaceASCII(kv[0], base::TRIM_ALL, &key);
        string value;
        base::TrimWhitespaceASCII(kv[1], base::TRIM_ALL, &value);
        headers.AppendHeader(key, value);
    }

    // Make sure to store the response, for later output.
    client.set_store_response(true);

    // Send the request.
    net::SpdyHeaderBlock header_block = net::SpdyBalsaUtils::RequestHeadersToSpdyHeaders(headers);
    client.SendRequestAndWaitForResponse(headers, body, /*fin=*/true);

    // Print request and response details.
    if (!FLAGS_quiet) {
        cout << "Request:" << endl;
        cout << "headers:" << header_block.DebugString();
        if (!FLAGS_body_hex.empty()) {
            // Print the user provided hex, rather than binary body.
            cout << "body hex:   " << FLAGS_body_hex << endl;
            cout << "body ascii: " << net::QuicUtils::BinaryToAscii(net::QuicUtils::HexDecode(FLAGS_body_hex))
                 << endl;
        } else {
            cout << "body: " << body << endl;
        }
        cout << endl;
        cout << "Response:" << endl;
        cout << "headers: " << client.latest_response_headers() << endl;
        string response_body = client.latest_response_body();
        if (!FLAGS_body_hex.empty()) {
            // Assume response is binary data.
            cout << "body hex:   " << net::QuicUtils::HexEncode(response_body)
                 << endl;
            cout << "body ascii: " << net::QuicUtils::BinaryToAscii(response_body)
                 << endl;
        } else {
            cout << "body: " << response_body << endl;
        }
        cout << "trailers: " << client.latest_response_trailers() << endl;
    }

    size_t response_code = client.latest_response_code();
    if (response_code >= 200 && response_code < 300) {
        cout << "Request succeeded (" << response_code << ")." << endl;
        return 0;
    } else if (response_code >= 300 && response_code < 400) {
        if (FLAGS_redirect_is_success) {
            cout << "Request succeeded (redirect " << response_code << ")." << endl;
            return 0;
        } else {
            cout << "Request failed (redirect " << response_code << ")." << endl;
            return 1;
        }
    } else {
        cerr << "Request failed (" << response_code << ")." << endl;
        return 1;
    }
}
